[Perf][3/n] Eliminate GPU<->CPU syncs in attention impls #41434
Conversation
Claude Code Review
This repository is configured for manual code reviews. Comment @claude review to trigger a review and subscribe this PR to future pushes, or @claude review once for a one-time review.
Code Review
This pull request introduces a series of optimizations across various attention backends to minimize CPU-GPU synchronization. The changes focus on using CPU-resident metadata for sequence length calculations, replacing synchronizing operations such as torch.nonzero and torch.bincount with asynchronous equivalents, and utilizing slice-based assignments to avoid implicit synchronizations. Furthermore, the async_tensor_h2d utility was enhanced to facilitate non-blocking host-to-device transfers. I have no feedback to provide.
@claude review
Performance-focused PR eliminating GPU↔CPU syncs across 7 attention backends; changes look correct and well-documented, but the breadth of critical paths touched (including subtle rewrites of torch.nonzero, torch.bincount, and mask-mod construction) warrants human review.
Overview
This PR eliminates unnecessary GPU↔CPU synchronization in attention metadata builders and impls across FlashInfer, FlexAttention, Mamba, Tree, Triton, and TurboQuant backends, plus shared utils.py, buffer_utils.py, and penalties.py. It also renames async_tensor_h2d's target_device parameter to device and adds a module-level PIN_MEMORY constant. The recurring patterns are: (a) tensor[0] = x → tensor[:1] = x / .fill_() to avoid scalar-assignment sync, (b) precomputing max() / max_seq_len / max_query_len on CPU instead of .max().item(), (c) building list-shaped tensors via pinned async_tensor_h2d rather than torch.tensor(.., device=cuda), and (d) replacing data-dependent ops (torch.nonzero, torch.bincount, repeat_interleave of GPU tensors) with sync-free equivalents.
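As a generic illustration of pattern (b) above (not the PR's actual code), assuming the metadata builder keeps a CPU copy of the per-request sequence lengths; tensor names are illustrative:

```python
import torch

seq_lens_gpu = torch.tensor([17, 5, 42], device="cuda")
seq_lens_cpu = torch.tensor([17, 5, 42])  # CPU-resident metadata copy

# Forces a GPU->CPU sync: .item() must wait for the device-side reduction.
max_seq_len_synced = int(seq_lens_gpu.max().item())

# No sync: the reduction runs entirely on the host copy.
max_seq_len = int(seq_lens_cpu.max())
```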
Security risks
None. This is a pure performance optimization — no auth, crypto, permissions, network, or input-handling code is touched. All changes are local to GPU kernel orchestration.
Level of scrutiny
Medium-high. The PR is performance-only and the bug hunting system found no issues, but it touches the hot path of every supported attention backend on a v1/nvidia-labeled PR. The non-mechanical changes deserve a human eye:
- `unique_static_unsorted` rewrite in `flex_attention.py`: routes non-kept entries to a garbage column `N` then slices it off. Equivalence depends on `scatter_` correctly handling duplicates by last-write-wins, which is fine for unique `dest_pos`, but the construction relies on `cumsum - 1` producing unique indices for kept entries plus the constant `N` for non-kept. Looks correct, but is subtle (a generic sketch of the trick follows this list).
- `torch.bincount` → `scatter_add_` in `utils.py`: the rewrite is functionally equivalent assuming `request_ids` values are bounded by `num_reqs`, which the surrounding code guarantees via `bucketize`.
- mamba_attn's switch from `compute_num_computed_tokens().cpu()` to `seq_lens_cpu_upper_bound[…] - prefill_query_lens_cpu` relies on the documented contract that the upper bound is precise for prefill rows; verified against the field docstring at `vllm/v1/attention/backend.py:405-408`.
- tree_attn's precomputed `max_query_len_prefill/decode` and `max_seq_len_prefill/decode`: needs verification that these flow correctly into the cached `prefill_metadata`/`decode_metadata` properties, which they do.
- The mask-mod simplification in flex (`torch.where(is_valid, .., False)` → `is_valid & ..`) changes evaluation semantics: the inner mask_mod is now always invoked with possibly-out-of-range logical indices. Likely fine because mask_mod implementations are pure index comparisons, but worth confirming for non-trivial mask_mods.
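For reference, a generic sketch of the nonzero-free compaction trick described in the first item; names and shapes are illustrative and this is not the actual `unique_static_unsorted` code:

```python
import torch

def compact_without_nonzero(values: torch.Tensor, keep: torch.Tensor) -> torch.Tensor:
    """Pack kept entries to the front without torch.nonzero (illustrative)."""
    n = values.numel()
    # Kept entries get consecutive destination slots 0, 1, 2, ... in order.
    dest = torch.cumsum(keep.to(torch.long), dim=0) - 1
    # Non-kept entries are all routed to a single garbage slot at index n.
    dest = torch.where(keep, dest, torch.full_like(dest, n))
    out = torch.zeros(n + 1, dtype=values.dtype, device=values.device)
    # Duplicate writes only collide in the garbage slot, so last-write-wins is
    # harmless; the output shape is fixed, so nothing has to sync to the CPU.
    out.scatter_(0, dest, values)
    return out[:n]  # slice off the garbage slot
```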
Other factors
- Maintainer `njhill` explicitly requested review.
- This is part of a series ([Perf][3/n]); prior PRs in the series presumably established the patterns being applied here.
- Each change is well-commented in-line, explaining the why (which sync was forced and why the new form avoids it). That makes review tractable.
- I have no prior reviews on this PR.
Signed-off-by: Nick Hill <nickhill123@gmail.com>
I didn't find any bugs, but this is a broad performance refactor across 7 attention backends with subtle correctness implications (replacing torch.bincount with scatter_add_, using seq_lens_cpu_upper_bound as a max-seq-len proxy in mamba/tree/turboquant, slice-assignment to dodge scalar-write syncs), so I'd like a human to give it a look despite the explicit perf wins shown in the description.
Overview
This PR eliminates several GPU↔CPU synchronization points across attention backends and a couple of supporting modules:
- `vllm/utils/torch_utils.py`: Adds a module-level `PIN_MEMORY` constant and renames `async_tensor_h2d`'s `target_device` parameter to `device` with a default `pin_memory=PIN_MEMORY` (the pinned, non-blocking upload pattern it wraps is sketched after this list).
- `vllm/v1/attention/backends/flashinfer.py`: Replaces a scalar write (`paged_kv_indptr_prefill_gpu[0] = 0`) with a slice assignment.
- `vllm/v1/attention/backends/flex_attention.py`: Reworks `_offsets_to_doc_ids_tensor` to build on CPU and async-upload, plumbs a new `query_start_loc_cpu` field through `FlexAttentionMetadata`, and rewrites `unique_static_unsorted` to use a scatter rather than `torch.nonzero`.
- `vllm/v1/attention/backends/mamba_attn.py`: Derives prefill context lengths from CPU upper-bound seq_lens and switches list→tensor uploads to `async_tensor_h2d`.
- `vllm/v1/attention/backends/tree_attn.py`: Precomputes `max_query_len`/`max_seq_len` for prefill/decode sub-batches on CPU in the builder, eliminating per-forward `.max().item()` syncs.
- `vllm/v1/attention/backends/triton_attn.py`: Switches `compute_mm_prefix_range_tensor` to async H2D.
- `vllm/v1/attention/backends/turboquant_attn.py`: Adds CPU copies of `query_start_loc`/`seq_lens` to metadata to avoid `.tolist()`/`.max()` syncs in the prefill path.
- `vllm/v1/attention/backends/utils.py`: Replaces `torch.bincount` with `scatter_add_` in `make_kv_sharing_fast_prefill_common_attn_metadata` and pushes index tensors to device up front in `make_local_attention_virtual_batches`.
- `vllm/v1/worker/gpu/{buffer_utils,sample/penalties}.py`: Updates calls to use the renamed `async_tensor_h2d` keyword.
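As a reference for the `torch_utils.py` item above, a generic sketch of the pinned, non-blocking host-to-device upload that a helper like `async_tensor_h2d` wraps; this illustrates the pattern, not the helper's actual implementation:

```python
import torch

def upload_async(data: list[int], dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    # Stage the list in pinned (page-locked) host memory so the copy can run
    # asynchronously with respect to the CPU.
    staged = torch.tensor(data, dtype=dtype, pin_memory=True)
    # non_blocking=True enqueues the H2D copy on the current CUDA stream and
    # returns immediately; building the tensor directly on the device would
    # copy from pageable memory and cannot overlap the same way.
    return staged.to(device, non_blocking=True)
```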
Security risks
None — this is a pure-performance refactor. No auth, crypto, network, or input-validation surface is touched.
Level of scrutiny
Higher than a typical perf change. Although each individual diff is mechanical, several rely on non-obvious equivalences (the seq_lens_cpu_upper_bound substitution being precise for prefills and merely safely-optimistic for decodes; scatter_add_ semantics matching torch.bincount with minlength=; repeat_interleave on CPU vs GPU producing identical orderings) that benefit from a human familiar with the v1 attention stack confirming no edge case slips through.
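A small sketch of the bincount/`scatter_add_` equivalence mentioned above, assuming an int64 index tensor whose values are bounded by `num_reqs` (names are illustrative, not the vLLM code):

```python
import torch

def per_request_token_counts(request_ids: torch.Tensor, num_reqs: int) -> torch.Tensor:
    # torch.bincount(request_ids, minlength=num_reqs) computes the same
    # histogram, but its output length depends on the data, which forces a
    # GPU->CPU sync; scatter_add_ into a preallocated buffer does not.
    counts = torch.zeros(num_reqs, dtype=torch.int64, device=request_ids.device)
    counts.scatter_add_(0, request_ids, torch.ones_like(request_ids))
    return counts
```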
Other factors
- A vLLM committer (njhill) explicitly tagged me to review; the maintainers are aware this needs eyes.
- The PR is labelled `ready` and includes a benchmark showing ~5% throughput improvement on TurboQuant.
- The bug-hunting system found no defects, and the comments throughout the diff explain why each rewrite preserves correctness, which is a good sign.
    prefill_start : num_reqs + 1
]
paged_kv_indptr_prefill_gpu[0] = 0
# Assign to slice to avoid cpu sync.
a lot of real black magic in this pr
cuda_tensor[0] = 0 uses copy_ which does a sync, cuda_tensor[:1] = 0 uses fill_ which doesn't :)
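A minimal sketch of the two forms being contrasted in the comment above (illustrative only):

```python
import torch

buf = torch.zeros(8, dtype=torch.int64, device="cuda")

# Scalar indexed write: per the comment above, this is routed through copy_
# of a host scalar and can make the CPU wait on the device.
buf[0] = 0

# Slice write of the same single element: routed through fill_, a plain
# device-side kernel launch, so no host sync.
buf[:1] = 0
```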
decode_max_query_len = int(num_decode_tokens.max().item())
total_num_decode_tokens = int(num_decode_tokens.sum().item())
Possibly but it would require more significant rework I think.
For now I am opening a series of PRs with "low-hanging" fixes. Remaining syncs can be wrapped in the gpu_sync_allowed() context manager once #40561 is merged; we'll at least know where they are and can decide if/when to put in additional work to address them.
The test patches `torch.fx.experimental.symbolic_shapes.make_symbol` in the parent process and counts via a `multiprocessing.Value`. In V1 the actual compile runs inside an `EngineCore` subprocess that vLLM spawns whenever CUDA is initialized in the parent (via `_maybe_force_spawn`), so the monkey-patch never sees the compile path and the counter stays at 0. This is a structural test-infra issue, not a regression: CI flagged it on the build for vllm-project#41434, but the same failure reproduces on its parent commit and is unrelated to that PR's attention-impl changes.

Replace the brittle torch-internal monkey-patch with the existing `compilation_counter.expect(...)` pattern already used by `test_aot_counters_on_save_and_load`. Force `VLLM_ENABLE_V1_MULTIPROCESSING=0` so the singleton counter is incremented in the same process that runs the assertions; the cache code path itself is identical in-process vs subprocess.

Add `cleanup_dist_env_and_memory()` between the two `LLM(...)` instances and lower `gpu_memory_utilization` to 0.1 to leave headroom on the second instantiation. The activation-registry reset is preserved: without it, GPT-2's `gelu_new` op leaves `disabled_custom_ops` mutated, the AOT cache hash shifts between phases, and `VLLM_FORCE_AOT_LOAD=1` raises FileNotFoundError.

Signed-off-by: haosdent <haosdent@gmail.com>
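A rough sketch of the restructured test shape the commit message describes; the counter field, model, and exact values are assumptions for illustration, not the real test code:

```python
import os

from vllm import LLM
from vllm.compilation.counter import compilation_counter
from vllm.distributed import cleanup_dist_env_and_memory

# Keep the compile path in this process so the singleton counter is visible.
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"

def test_aot_cache_is_reused():
    # First instantiation is expected to compile (field name illustrative).
    with compilation_counter.expect(num_backend_compilations=1):
        llm = LLM(model="gpt2", gpu_memory_utilization=0.1)
    del llm
    cleanup_dist_env_and_memory()
    # Second instantiation should load from the AOT cache without compiling.
    with compilation_counter.expect(num_backend_compilations=0):
        llm = LLM(model="gpt2", gpu_memory_utilization=0.1)
```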
…t#41434)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
Signed-off-by: Libin Tang <libin.tang@intel.com>
Unnecessary GPU/CPU syncs in attention implementations, found via #40561.
TurboQuant benchmark
Each scenario runs vLLM with `--tensor-parallel-size 1 --distributed-executor-backend uni` (UniProcExecutor) on a single NVIDIA GB200 GPU. Model: `Qwen/Qwen3-0.6B`. Each side (without / with change) is the mean ± population std across 3 timed runs sharing one server process; each run uses its own seed (1, 2, 3) and is preceded by a fresh warmup batch. Δ = relative change of with-mean vs. without-mean (✓ = improvement, ✗ = regression).

TurboQuant c=32, 256 in / 2048 out
Per-side metadata:
Without change:
With change: